{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5684c89e",
   "metadata": {},
   "source": [
    "## Attention based Spatial-Temporal Graph Convolution Cell\n",
    "Model that chains blocks containing a temporal and a spatial attention layer followed by <br />\n",
    "a convolution layer on the time axis. <br /> <br />\n",
    "\n",
    "<pre>\n",
    " &#8625;  res_channels   &#10141;                                             &#8628; <br />\n",
    "&#10141; [temp_attention &#10141; spatial_attention &#10141; time_conv] &#10141; block_out + residual &#10141; ... &#10141; last_block_out <br />\n",
    "                                                                                            &#8681; <br />\n",
    "                                                                                        prediction\n",
    "</pre>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ee19d01",
   "metadata": {},
   "source": [
    "### ASTGCN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7e473b3a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [16:16<00:00,  1.02it/s]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEWCAYAAABrDZDcAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAA0/klEQVR4nO3dd5hU1f3H8fd3G0uvS+9KEVEsK2IFVBQxsTdioqgJJrFHY4stmtg10cRGjDUqNqLEhqIg/ESUjtI7LHVZ6i5sP78/7p3Zmd3Z3VnY2WV3Pq/nmWduOXPvuTtwv3POueccc84hIiLxK6G2MyAiIrVLgUBEJM4pEIiIxDkFAhGROKdAICIS5xQIRETinAKBSB1jZpPNzJnZqNrOi9QPCgRywDCz1f4N7tzazsuBwMyG+H+P1aV2vQ88DSys+VxJfZRU2xkQETCzZOdcQTRpnXP/jHV+JL6oRCB1hpmdZ2YzzGy3ma0xs2fNrIW/L8XM/mVmm8wsz8zWmdn//H1mZg/52/L8NBPMrHU552lsZo+b2QozyzazuWb2K39fTzMrNrMsM0v2t3Xzf7ln+flIMrPbzGyRmeWY2UIzGx1y/Pv99O+b2btmthe4rFQehgCT/NXA8Z2/L6xqyMxe9ddfN7PPzGyvmX3h5+sDPw/fmVmPkOP3N7NPzGyLmWX66bpWw9ckdZACgdQJZjYCGAcc7r/vBn4PjPWTXA78GtgK/BuYBRzv7zsVuBMo8vdNAQ4DmpZzuleAW/307wK9gNfNbKRzbiXwLdAKOM1Pf7H//o5zLh94EHgUMOAtIBV40cyuKHWeC4CDgDeATaX2ZQAf+Mu78aqCni4nvwG/BLKBbcAwYB7QAlgJDPLzhZm19/8Gw4D/AyYD5wMTzKxBJeeQekiBQOqK6/z3h5xzVwBDgELgDDPrDST7+38E3gSuBNr62wL7luPd2K8DOgFrS5/EzNoCF/mrw5xzVwF3+evX+++v+++X+O+BQPC6mVlIXqcBOcBP/vrvSp1uJXCsc260c+7z0B3OueVAoApom3PuJufcTaXzW8rXzrmLgH/563vxbvaB/B/pv/8KaIn391gLrAcygb7A0ErOIfWQ2gikrujuvy8CcM5tNbOtQHugG97NeQhwDnAp4ICJZnYe8AXwHN4NMFDdMhM4G9hYznn2OufW+MuL/fdu/vu7wDPAuWbWF0gHljrnpptZGtDET3dlqWMfXGr9B+dcYWUXXgWL/Pcd/vty51yxme321xv7793990P8V0V5lDigEoHUFav9974Afv1+G3/bGqDQOXcJ0Azv5jYR79fw+UAi3q/0Fng3utfxbt6/ruA8DUPqzPuEnAfn3E7gI6A5MMbfFyglbMUrBQAMcM6Zc87w/q+llzpXXiXXXOS/R/v/tKiS9YDV/vt/A/nz89gBr+pM4owCgRyIHjWz6SGvk4Fn/X13mdmrePXaScCXzrmlwEgzW4RXv38jXhsAeL+OjwdW4VUZ/QE4IWRfGOfcFrzHMwG+NLOXgYf89dCndQI3/pPwSh9v+J93IXn9wm/AfhuvGuj+Kv0VYJ3/3tnMXjKz26v4+fK8iXft5/mN5i+a2UT/fO2q6RxSh6hqSA5EvUutt3LOfWhmFwN3ABfiNYi+iNcIDLAE79f4CLxG4I3AX4CP8UoBy/AajVv46V6g5Nd8aVfh3RTPw2sHWAE85Zx7KyTNBLwG3vbAN8650PaGu4EsYBReA+4uYA7wTpTXD4BzbrWZPYFXcrkaWIDXCL1fnHMbzGww8FdgIHAiXlvBs3h/G4kzpolpRETim6qGRETinAKBiEicUyAQEYlzCgQiInGuzj011KZNG9e9e/fazoaISJ0ya9asrc65tEj76lwg6N69OzNnzqztbIiI1Clmtqa8faoaEhGJcwoEIiJxToFARCTOKRCIiMQ5BQIRkTinQCAiEucUCERE4lzcBIIlm3bz1BdL2Jpd2VwgIiLx
JW4CwbItu3nm6+Vsy8mv7ayIiBxQ4iYQGAaApl8QEQkXP4HAiwM4FAlERELFTyDw31UiEBEJFz+BIFAiUCAQEQkTN4EgUCZQ1ZCISLi4CQQqEYiIRBY/gaC2MyAicoCKn0BgenxURCSS+AkE/rvaCEREwsVPIFAbgYhIRPEXCGo3GyIiB5z4CQTBISYUCkREQsVNIEAlAhGRiGIWCMzsZTPbYmY/VZBmiJnNNbMFZvZNrPICGmJCRKQ8sSwRvAoML2+nmbUAngPOds4dClwUw7wEHx9VmUBEJFzMAoFzbgqwrYIkvwDGOefW+um3xCovoBKBiEh5arONoDfQ0swmm9ksM7u8vIRmNtrMZprZzMzMzH06malrsYhIRLUZCJKAo4GzgDOAe8ysd6SEzrkxzrl051x6Wlrafp1UBQIRkXBJtXjuDCDLOZcD5JjZFGAAsDQWJ9MMZSIikdVmieAj4EQzSzKzRsCxwKJYnaykZ7EigYhIqJiVCMzsbWAI0MbMMoD7gGQA59wLzrlFZvY5MB8oBl5yzpX7qOl+58d/VxgQEQkXs0DgnBsZRZrHgcdjlYcwGmtIRCSiuOlZbJqhTEQkovgJBKobEhGJKH4Cgf+uOCAiEi5+AoFmKBMRiSiOAoH3rjYCEZFw8RMI/HeVCEREwsVPINB8BCIiEcVNIEAzlImIRBQ3gUAlAhGRyOInEAQWFAlERMLETyAw9SwWEYkkfgKB/64mAhGRcPETCDTonIhIRPETCIKDzomISKj4CQSamEZEJKK4CQQBCgMiIuHiJhCojUBEJLL4CQQaiFpEJKL4CQQqEYiIRBSzQGBmL5vZFjOrcEJ6MzvGzArN7MJY5cU7j/euOCAiEi6WJYJXgeEVJTCzROBR4IsY5sM7F5qYRkQkkgoDgZklmtnZZtanqgd2zk0BtlWS7HrgA2BLVY9fVZqYRkQksgoDgXOuCPg3cFx1n9jMOgHnAc9HkXa0mc00s5mZmZn7dj7/XSUCEZFw0VQNvQmMMrNDzaxV4FUN5/47cLtzrriyhM65Mc65dOdcelpa2j6dTG0EIiKRJUWR5ga8++f8kG0uys9WJB0Y648K2gYYYWaFzrkP9/O45dDENCIikURzM59CDH5IO+d6BJbN7FXg49gFgZISgYiIhKs0EDjnhuzLgc3sbWAI0MbMMoD7gGT/mC/syzH3h9oIREQiqzQQmFlz4GngTH/TJ8DNzrmdFX3OOTcy2kw450ZFm3ZfaWIaEZHIomksfga4HMj3X6PwGnrrFJUIREQiiyYQnAk85pzr4pzrAjwOnBXbbFU/DTEhIhLZvvQs1q1URKQeieapoU+BP5rZL/z1TsDrsctSbGiGMhGRyKIJBDfhlRwCjcVvADfHKkOxohnKREQiqzAQ+IPC3QO84py7vGayFFsKAyIi4aIZa+hc4KAayU0MmealERGJKJqqocnAvWbWANgY2OicGxerTMWC+hGIiEQWTSC40n9/xn83vN/ViTHJUYyoH4GISGTRBII/xzwXNUCjj4qIRBZNY3EzvAHhJtVMlmJDM5SJiEQWd43FaiMQEQkXP43F/rtKBCIi4eKmsRi1EYiIRBRNIHiAenD/NDTqnIhIJNFMTHN/DeQj5vTUkIhIZOU2FpvZbDMbZmaNzexlM+vrbz/PzLbVXBarh9oIREQiq+ipoSOAlkAq3mQ0Hf3tKUDzmOYqBoI9ixUJRETCVDYfQb25a2qoIRGRyCoLBFcBD+PdP68zs2fwSgeV8quTtpjZT+Xsv8zM5pvZj2Y2zcwGVCXjVaUZykREIqussfiMkOVzQ5ajuZ2+CvyT8iexWQUMds5tN7MzgTHAsVEcd59oYhoRkcgqCgRD9+fAzrkpZta9gv3TQlanA53353yV0sQ0IiIRlRsInHPf1GA+rgY+K2+nmY0GRgN07dp1n04QnI9ARETC7Mvk9dXKzIbiBYLby0vjnBvjnEt3zqWnpaXt23mCx9qnj4uI1FvR9CyO
GTM7HHgJONM5lxXjcwEadE5EpLRaKxGYWVdgHPAr59zSmJ/Pf1eJQEQkXLklAjO7t6IPOuceqGi/mb0NDAHamFkGcB+Q7H/2BeBeoDXwnP9rvdA5l16VzFeFhpgQEYmsoqqh+0OWHWX7ZFUYCJxzIyvZ/2vg15Xkr9poYhoRkcgqCgQX+e9DgcHA3/Cqkm4EvotxvqqdJqYREYmsosdHPwAwsyeAvzrnXvbXDbitZrJX/VQiEBEJF81TQw2A+8ysM16J4EoOgMdOq6pBUgJmkFdQVNtZERE5oEQTCG7Fe8Qz0HicizcGUZ1iZjROSSI7T4FARCRUNBPTvGVmE4FB/qbpzrktsc1WbDRukEhOXmFtZ0NE5IASbRXPMcApwHLg9FiPFBorjVOSyM5XIBARCVVpIDCzm4D/AdcD7YHzgcdjm63YaNwgSSUCEZFSoikR3AS8F7I+ETgqJrmJMYdj8pJMpi7LrO2siIgcMKIJBC2BeSHrjYDE2GQnttZk7QHgma+W1XJOREQOHNE8NfQD8Dt/+VbgRODbmOUohlKTE9mdW0higsakFhEJiKZEcD2wF2+IieHARrzqojqnQZJ3uQoEIiIlKiwRmFki0BuvgbjY37zEOVcnH8ZPSQwEgjrXH05EJGYqvCP6N/x/A+nOuYX+q04GASA4bJ7KAyIiJaJpI3gTGGVmM/CqhQBwzm2LWa5iJFAiyC8sriSliEj8iKaO5AbgJGA+kOm/6mTP4ucu85567diiYS3nRETkwBFNiWAK9WQ+l55pTejcsiGbd+XWdlZERA4Y0Yw1NKQG8lFjUpIS+L/lW8kvLCYlSY3GIiKVBgJ//oFLgcOAVH+zc87dEsuMxcrZAzry94nLWLp5N/07Na/t7IiI1LpofhI/i9dgfDte/4HAq0Jm9rKZbTGzn8rZb2b2jJktN7P5ZlYjw1YM798eKOllLCIS76IJBOcBb/nLNwKTgAej+NyreB3QynMm0Mt/jQaej+KY+61VoxQAtu/Jr4nTiYgc8KIda2iqv7wReB/vxl0h59wUoKJHTM8BXnee6UALM+sQRX72S4tAIMhRIBARgegCwSa8toRNeDOVPRnl5yrTCVgXsp7hb4upQAPxk18uJVfTVoqIRHVDvxtYAdyCN03lTmp4rCEzG21mM81sZmZm9Q0h/cOqOtcnTkSk2lUaCJxz/3HOfe6cG+uca++c6+CcG1sN514PdAlZ7+xvi5SHMc65dOdcelpa2n6f+MZTewGQYBpsQkQkmhnKvo7w+qoazj0euNx/emgQsNM5t7GyD1WHU/q2BSCvUFVDIiLR9CweEmFbpT2Nzext/7NtzCwDuA9IBnDOvQB8CozAmwd5D3BlVDmuBqnJ3rw6uQUac0hEJJpAEFoX0xK4n5DB58rjnBtZyX4HXBvF+atdarJXELr2rdkc0/1U2jZLreQTIiL1VzSNxS7ktQtYAlwRy0zFWoOkkpk2V23NqcWciIjUvmhKBFspWxW0JAZ5qTGBEgFAUqLGGxKR+FbV0UeLgNXAE7HKUE0ItBGApq0UEYm70UehZO5igIIiNRiLSHyLZvTRlyvY7ZxzV1djfmqEhfQf0GxlIhLvoqkaGoVXNRS4e5ZernOBIJQCgYjEu2haSp/AG3TuNOB0f/lJ4BhgYOyyFluXpHudmmev3V7LORERqV3RBILLgXecc1875yYC7wIXO+dmOedmxTZ7sfObk3sA8I+vl9dyTkREalc0VUN7gYf9YSAMOBvIimmuakBKYmLliURE4kA0JYJf4wWDXwG/xBsO4texzFRNCJ2v+KtFm2sxJyIitSua0Ue/AroBR/iv7s65SbHNVuyFdh+4+rWZtZcREZFaVmEg8CeuxzmXD3QAhgGDayBfMde2WSqNUkqqh4Y99Q079xTUYo5ERGpHuYHAH2r6S3/5arzRQh8DPjezu2sme7F1cXrJdAjLtmQzeekWAK54+QcmLNhUW9kSEalRFZUI+gOf+Mu/9d8f
BL4BfhPLTNWUnw8InyI5wYxduQV8szSTa96osw9EiYhUSUWBoDmQZWbNgSOBtc65+4HXgLY1kLeYO7pbK1745dHBdTNYmemNRtqmSYPaypaISI2q6PHR1XjzFF+EFzA+97d3pR48PhrQpEHJn6DYwTsz1gLQpVXD2sqSiEiNqqhEcA/QBzgLbyjqJ/3tlwLTY5yvGtOqcUpweW9+IXn+kBO92zZlRWY2/52TUVtZExGpEeWWCJxz75nZ10BPYJFzLtvMkoBfAPWmJbVD85LZyZZvyWbLrjzAG5V0xNNTySss5rwjO9dW9kREYq7CnsXOuSxCqoGcc4XAvFhnqia1aJQcXP7X1FXB5eWZ2cHSQUFRMcmawEZE6qmY3t3MbLiZLTGz5WZ2R4T9Xc1skpnNMbP5ZjYilvkpJ4+MHT2I7q0bhW2fn7EzuJxbUFTT2RIRqTExCwRmlgg8C5wJ9ANGmlm/UsnuBt51zh2J1/bwXKzyU5FBPVvz1S1DuPrEHhH33z9+IT+t3xlxn4hIXRfLEsFAYLlzbqXfM3kscE6pNA5o5i83BzbEMD8VSkww7vlZ6Tjl+WB2BpeOqTft4yIiYaKZoawPcCvQHQiMyeCcc6dW8tFOwLqQ9Qzg2FJp7ge+MLPrgcZ4cx5EysNoYDRA165dK8tyTBQ7V3kiEZE6KJoSwYd4s5CdCgwJeVWHkcCrzrnOwAjgDTMrkyfn3BjnXLpzLj0tLa2aTh3ZyociN1MEJrn/v2VbGTdbj5SKSP0RzXwErYC/4Y0zVFiFY68HuoSsd/a3hboaGA7gnPvOzFKBNsCWKpynWiWEDksaYnduId3v+CS4fv5ReqRUROqHaEoErwMHA03w6vQDr8rMAHqZWQ8zS8FrDB5fKs1avJIGZnYIkApkRpd1ERGpDtEEgluAnwFL8W7SmUTxi93vc3AdMAFYhPd00AIze8DMzg459m/MbB7wNjDKudqvjL99eN99+tz0lVl8NLd0oUdE5MBmld13zWwyEUoAzrmhMcpThdLT093MmbGfSGbw45NYk7WHNk0asDU7L2Kav19yBOcc0RF/2oZg1dHqR86Kef5ERKrCzGY559Ij7YtmhrIhzrmhpV/Vn80DywV+G8DXt5Y/D89N78zlhrFzayhHIiKxEc3jo4ZXv38YXh0+eI+P3hLLjNW26085mGsG96RBUiL/u+5Etu/J56Z35rItJz8s3f/mbeCgtMbccEqvWsqpiMj+ieapoWfxJqZxQOCRGodXv19vmRkNkrxuE4d1bg5Avj/2UGl/n7iMv09cVmN5ExGpTtE0Fp8HvOUv3whMwpupLO5k51Xl6VlPbkERD3+2iKxy2hlERGpbNIGgJTDVX94IvI/fy1cq98q3q3nxm5W89f3a2s6KiEhE0VQNbfLTbQJeAlKAXbHMVF3nnCMzO4+s7HyWb8kGoGXIBDgiIgeSaEoEdwMr8NoEcoGdwE0xzNMB65J0r6N0ZY+HFhU7Rjw9lTOfnhocwvruD39ibdaemOdRRKSqKi0ROOf+A2BmLYBuzrm4rex+9MLDefTCwytNl5NfxNZs7+mi2Wu3B7e/Mm0Vh7RvxsXHdCnvoyIiNa7SEoGZdTezGXjzFp9kZt+Y2QOxz1rdcXq/dmHrf5+4NLi8cWducPmVb1dz2wfzGf167DvEiYhEK5qqoRfwhpQ2oBiYgtevIK49dN5hvHR5OqsfOYsnLh4Qtu+Vb1dX+NkvFm5m/LwNwWqjRz5bzKVjvgvuzy0oYuBfJzJx4eZqz7eISGnRBILjgX+GrK/AG0k0rv3i2K6c5pcEmqUms/qRs7hu6MFRf/6Gt+dw6H0TyMrO44VvVjB95Ta2+53VNu3MZcvuPB74eGFM8i4iEiqaQLAV6O8vt8UrDdTaTGIHsk27csPWj+7WssL0RcWOeRk7guu/KVVl5KIa5FVEZP9EEwj+hXfzN+BNYBjwYiwzVVed2b99cHnuvcN49ILDOPeIjhV+ZsWWnODy
0s27Acgv8nowVzYO67pte9idWxBcPgAGbhWROiiaQeceBq7E60j2AXClc+7xWGesLjr1kHasengEix8cTotGKRzctil/v/RI3v/tceV+ZtqKrcHlXbmFnPLEZBZs2AlUHghOemwSFz7/HbPWbOekxyYxdsa6ij8gIhJBVJPXO+dec85d7L9ej3Wm6jIzIzU5MWxbevdWweUGSeF/8klLwufhWbk1h09/3AQQ1S/8JZt3s8wvScxes72S1CIiZZXbj8DMiir4nHPORdMrWXy3nt6bQT1bk5qcyJvfr+HtH8r/9f6l/7RQRWEgNEgElhIs8jSbIiIVqehmbnj3mA3AjhrJTT12Xcgw1Q+ffzh3jTiEez9awH/nlD+jWeBen1tQxIINu8IanwuLSwJBsZ8wIarynYhIuIpuHa8AOXiTyf8I/ME5d1jgVSO5q8eapibTu11TAM44tF3ENDv3eg3Bd437kQuen8ZHc9ezJstrXM4LGRJ79dZAg7NKBCJSdeUGAufc1UAH4PdAF+BzM1ttZsNrKnP13cXpnTmiSwtuPb0Pr181sMz+vQVF9Ln7M8b5pYYbx85l8OOTeev7tfS/b0Iw3b+mrgIgoZI4kF9YzB/enasxj0QkTIWVCc65HGAlsArIxysdNI324GY23MyWmNlyM7ujnDQXm9lCM1tgZm9FSlNftW7SgA+vPYFe7Zpy3EGtOf+oTnx1y2Dm3Xt6ME1ehMlw7vrvjxGPF2gj2JNfyLptZW/236/KYtzs9dz53/nVdAUiUh9U1Fj8J2AU0BP4HrgeeMc5tzuaA5tZIt7sZsOADGCGmY13zi0MSdMLuBM4wTm33cza7uuF1HXJiQk8dfERwfUurRqybtveKh0jUCK46tUZTF+5rcwoqeZXHam7gYiEqqhE8CBeEFiJ17v4bOBNMxtvZh9FceyBwHLn3ErnXD4wFjinVJrfAM8657YDOOe2VPUC6qsXf5nOlSd0r9JnzC8RTF+5DYCCovDSRCBQFCsSiEiIyh4BNeAg/xUqmjtJJyD0GckM4NhSaXoDmNm3QCJwv3Pu8zKZMBuNPyta165dozh13devYzPu63goww5pR15hMVe+OqPSz5jBJ/M3Btf35BfRvGFIrPcDQSAO7NxbQIJ5DdciEr8qCgQ9auj8vYAheAPZTTGzw5xzO0ITOefGAGMA0tPT4+rn7PEHtwk+PVQZw7j2rdnB9T35hTRvWHKTLyjy/nSBQDDgz1+QnGgs++uI6suwiNQ55QYC59ya/Tz2erynjQI6+9tCZQDfO+cKgFVmthQvMFT+8zeONG+YzNTbhjJzzTbSu7WiqNiRlGj84l/fszakUbiwOLwqaMjjk8krLOa1qwYyuHcaef6w16GD2QWCQ2l5hUWkJCYEq5tEpP6KZRekGUAvM+thZil4A9eNL5XmQ7zSAGbWBq+qaGUM81RndWnViPOO7EyXVo3o3qYxnVs2YsptQ/nLuf2DaT6aGz4obOCJoyte/oG9+UWMmeL9aRdv3M27FYxLtHNPAX3u/jyYXkTqt5gFAudcIXAdMAFYBLzrnFtgZg+Y2dl+sglAlpktBCYBf3TOZcUqT/XRLwd1Cy5XVIU09InJzPTHItqdV8htH5T/CGlmtjec9jNfLQvbvmNPPnd/+CN78ysafURE6pqYjhfknPsU+LTUtntDlh3wB/8lMVR6roRQs9du56iu3vAVa7JymLjIe3grJ7+IvMIiGiR5g+j94+vl/Gf6Wvq0a8qvjute6TnXbdtDflExB6U12f8LEJGY0cBx9ci5R3RkeP8OPPXlEvq2b8ZNp/Xi7H9+S3ZeYYWfO/+5aYw6vjtfLNjEhp3hAWPHngLaNfMCQeCx06e/WsYph7SjU4uGgFeVNOCBL7j0mC48csHhwc+e9NgkgDL9GUTkwKJAUA/MuWcYiYlGM/8x0OEhE+TMvmcYm3flcv7z0xjQuTkjB3bl6tdmljnGq9NWRzx2xvY9bNyZS1FxMcmJ
Xk3i1ux8bh47l3f9eRb+NdVrSxg7Yx0PnXcYCZWNdSEiBxQFgnqgZeOUcvelJCXQpVUjZvzptOC2swd0ZPy86GYbveD574LLvx9S0p3kh9XbWLBhJ4d2bM4/Jy0Pbt+ak0fbpqlhx1iRmU2Lhsm0btIgOHy2nkYSOXBo4OI4dM3gnrRp0oCbT+vNiQe3iZgmKcKv+tLzHdw0dm6ZyXM27CjbFnHqk99w+t+msHNPAT3u/JQed35aJk1F8guLyY8w5pKIVA8Fgjh0aMfmzLz7NG48rRevXTWQC47qDMBjFx4eXP74hhM56/AOYZ8L/eUPsGxLdpmb+iUvfsezk5aXuXFn5eQze23JDGpfLNjEsKe+ieoJpJMe+5qjHvwy+gsUkSpRIIhziQnGkxcPYPUjZ3FxeheevHgAqx4eQd/2zXjq4gERP/PsL45iYI9WEfflFRbz+IQl7NiTX3ZnSIFi9BuzWLYlO9ghbueeAr5evDniMTfvyquwwXt+xg6e+nJpxH0LN+yisEilCZGKKBBIGYH6+wZJiXx588m8/ZtBwX1z7x3GWYd34JHzK56baOBDX5XZduUrZTuMZ2XnMW52BpeM+Y6rXp3JD6u2cebTU3nz++g7tp/9z2955qtlFBWHV1Ot3prDiGem8shni6M+lkg8UmOxVKhXu6b0agfjfn88ExdupkUjr2G6Z1oTzuzfns9+2sQ/Rh7Jzwd05M5x8yucizmSjB17ue39ks5tU5dlsmjjLv70359YsGEXbUIawguKipmfsYM3p6/lrrMOoU2TBmHHyissolFKyT/prByvVDJjzXZEpHwKBBKVo7q2DHY6C7jzzEPILShiaF9vGonQm3C0XpoaPozFP74uaYd46/u1Yfu+XryFa96YBcC4Oeu5bujB3HpGn+D+ZZuzeXfmOi48ujPz1u3gCD+/xcVxNU6hSJUpEMg+69q6Ea9cWTLF5s3DenN45+bcOHYuAM+MPJJxszOYvCSTUcd359COzVi2JZsxU1byh2G9aZicyF8/XRT1+QJBIOCfk5bTv1Pz4PoNY+ewJmsPb/oB5D2/n0PpKqOKvD8rg34dmtGvY7OoPyNS1ykQSLVp0iCJc47oRLOGyezOLeTsAR352WEdKCguDg5TAfDrk3rQtmkqO/bkhwWCy4/rxuvfVW3Q29/+pyQ4rCk1F/NGv5d0ZRPxZGXn0TQ1maQE49b35gHqDS3xRYFAqt3QPiUzjiYkGA0SEsP2BzqctWiUwo2n9uLpr5bxwe+Op3nDJF7/bg0jB3ahYXISL3+7ar/yEZi3OTQQ7NxTQHKSsWFHLuu27+GlqSv5dnkWZxzajgfPKRnJNTuvkCYN9N9D4oP+pUutunlYby4b1DUYHBY/OJzU5ESKih17C4r475wMnrzoiLAJdwLu+3k//vy/hWW2Bzw+YQkQOiGPY8ADX0RMO2HB5rChOfrfN4E/n30oJ/dOo0ebxlW6pvS/TGRYv3Y87D9ZlVdYxIdz1tOyUQopSQkM6RO3U3PLAcpK9ww90KWnp7uZM8uOlSP123/nZHDzO161zeDeaTRukMhfzj2sTEeza4cexCvfrmZPqY5qS/9yJnPWbueSMdOrfO7Lju3Kuu17ee3KY8odGmNPfiHJiQk4B73v/gwoqV566sulYUN6q9pJaoOZzXLOpUfapxKB1AnnHdmZo7q2ZMOOXI47qHVw+x/P6BP85e+t9+WPZ/Tlipd/4JulmcHtd3/4I+/OzNincwcanzftyqVD84YR0/S7dwIn907j3p8dUmbf5p3lDwEuciBQhzKpM7q1bhwWBACuHXowp/ZtS9dWjXj3muOC21+7aiAf/O44ju7mPUIaCAJ92zcN+3xa0/C+CBW56IXvgpP/LNu8m+53fMIb09cEx1uasjSTuet2BtNnZecBMGPNtqjPIVIbFAikzvv3qGOYctvQMsNeHN2tFR/87vjg+jHdW/LJDSfx05/PAGD0yT1p38xrmxg7ehCVydi+
l5/9YypbduUy7G9TALjnw59Yv2NvME3gqSOAC1/wRm5dmZkTdpxef4o86F5hUbH6PEitUNWQ1Hv9OjRj4cZdvPdbLyg0aZDE4geHk5KYwOJNuxk3O4MjurQIpu/YPJUurRpxUNsm/OaknuzYk881b8xiy+481m3bW2b4jBMfnRTxvKu25kScPrSgyFFYVExSYvjvsF53f8bJvdJ47aqBZT5Tl/xn+hqe+nIp/7v+xODkRXJgUyCQeu+dawaxKzd80LrUZO+R1n4dm9GvYz8A3r3mOFo1TuagtCalGoUbM6RPWpk2hk4tGoaVBiJ5euKyiNt37C3gx/U7aZaaRP9OzTn32Wk4R1i7BsDEhZv5x9fLOLpbK+79eb9oLheAy1/+gTlrtvOjX/qpKf+3bCt3f/gTAB/P28A1gw+q5BNyIIhp1ZCZDTezJWa23MzuqCDdBWbmzCxii7bI/miamhzVL9OBPVpxcNumEZ8MOqRD2Z7GH157An8a4TUOtw1pa7jm5J7B5fL6QqT/ZSJXvjKDC57/jg07clm0cVdw3xvfraa42LF6aw6/fn0m8zJ28vK3q3jl21V8PD/yhELj521g8aaSY0xZmsnuSqYoDdidW1BtI7Rm5eQFlx/WYH91RsxKBGaWCDwLDAMygBlmNt45t7BUuqbAjcD3scqLyP761aBuJCcm0CApgYUbd9G5ZSPSmjbgNyf35JeDupGQ4LUh7M0vone7prw4ZWWZY6QkJpAf4YY79InJYev3fLSAez5aUCZdoM/E6f3as2lnLquzcji5dxp78gu54e05APztkgFhTzYVFBWzJ7+I5g2TI15XcbHjsPu/YOTALjx8/uER01RFSqKaHeuiWFYNDQSWO+dWApjZWOAcoHQPoAeBR4E/xjAvIvslKTGBXw7qFnFfwxSvmumgtCbBbdPuOIXjH/kagMYpiYy//kQ6Nm9IcqLx7KQV/G1i5PkTonHyY5PYtCvyI6mBvhYBt743j4/mbmDVwyPCSjo79xTw3cqtBOLSO/5800s276ZPu/BS0aTFW0hMME7unVZp3hJLzWyXX1hMSpKCw4EuloGgExA6JnEGcGxoAjM7CujinPvEzMoNBGY2GhgN0LVr1xhkVaR6dWzRkPn3n86s1duDo7MGXH/KwcFA8O41x7G3oIjXpq2mW+tGfDJ/I1t250U6ZFB5QSCSj+Z6VUkfz9/IlKWZvDcrcl+K5MQEZq/dwQXPT+P6Uw7mltNLRnW98lVvHonPbjyJ9s1SK5wjO9CLO2Dzrlx+Wr+ThRt3hR1TDiy11lhsZgnAU8CoytI658YAY8DrWRzbnIlUj2apyWWCAHjjL93/834UFrvgI6+D/V/blx3bjc27cjm0YzMSE4xJSzLZsiuX5yavoFlqEqtDBtb7w7DefL14C3PX7ag0L9f7VUflMYMNfsP3N0szgzft0JFbz3x6Kj3TGvP1LUMA+Gjueo7o0oJurUuG4MgtCO/RvWHHXn73pjc8yPWn9FLp4AAVy0CwHugSst7Z3xbQFOgPTPaLoe2B8WZ2tnNOY0hIvTbqhB4Rtx/ctgkHty2pYjp7QEcv/fHdSTBj3Jz1wb4KVxzfneuGHkzPuyL3S6iK3ILiYLAoKHLk5BXyl08WMqBzi7B0KzNzWL9jL8XFLjjc+Hu/PY5junsBLbcwPBBMWVbyFNTSzbvp1roRTVOT/XMWkWAWMTgUFzue+nIpR3VrwSl92+339UnFYjbWkJklAUuBU/ECwAzgF865sq1gXvrJwK2VBQGNNSTxbtqKrRzVtWXwEdjduQVs2pnLuDnreX7yCm4+rTcNkhPYtbeA5yavCPts09QkXh51DBt27A3eyM86vAOfzN8Ynq5BUtRPHUHJ+En//r9VPPhx5IEArziuG699t4YXfnk0w/u355QnJ5NfWMyjFxxO3/ZNae3POOeco8edn4YdO6+wiPXb9zJ+3gauP6UXa7ftYd22PVG1W4inVsYacs4Vmtl1wAQg
EXjZObfAzB4AZjrnxsfq3CL12fEHtQlbb5qaTNPUZG4f3pfbh/cNbt+TX0hqciI/H9Ax+GTSj/d7/Qq27PbaGQb3TuOPp/dhxqptYW0TVQkCAN3v+ITBvdOC/SAWPzicU56YzIaQcZZe8+eauH/8grB5JC57qeSBwa9vGVymz4dzjitfmcG0FVkAHNGlBaNfn0V+UTGf3HAih3b0JifKySukyDmapZY8IfX5T5v47KeNPHHRAJJLPdGUk1dIYoKRmpzIlws3M3P1Nm49o0+ZdPFAo4+KxIE1WTnkFhTTJ2SspeJiR0LIUz7OOXblFjLgz+FDdT924eEc3a0lpz75TZnjjj65J2NKPSqbkpjAkr8M5/YP5vPuzAzeGT2IJ79Yyg+r923MpYUPnEG/eycE128f3pdHP/f6KCQlGLPuGcZ7M9fx4pSVZO7O49MbTqJfx2as3prDkJBHc6fdcQodQ/qTdL/jE3q3a8KEm04OlkD+el5/Ljs28tNh1ck5R35R+IRNsabRR0XiXGiDbkBCqUc9zYzmDZOZfOsQOrRIZfaaHTT1ez4DzL13GBMWbOK5ySv4/MaT2bm3gOYNk1m/Yy+79hYwddlWmqYm8f1dp2JmPHBOf0Yd34N+HZtx42m9eG7ycuas3VFmiPBIXr3yGF6dtprJSzI5rVQACgQBgAZJCWUC14hnprL4weFhQQBgyebdwUDwgf/01NLN2WzeVVISCl2O5INZGZzUqw1t/TGq9tXfJy7j6a+WBeffqG0qEYjIfnPOsTprT6WT+GzPyedIfw4JMwi9/ZzUqw03nNqLbq0a0bZZKp/M3xhxQqKAQzs2Y8GGXeXuL+3Bcw7lV8d1J3N3Hsf8dWLENOce0ZGkxASWbclm3rodfHTtCfTv1Jw3v1/D0D5tOekxb1ypSHNKfDR3PRnb9/L7IQeVO29FwNEPfklWTj5T/jiUrq0bRX0N+6OiEoECgYjUqF25BeQVFJPWtAEFRcXs3FvA/eMX8LshBwXr+wGWb9nNaU9NiXiMPwzrTeeWDfnDu94TVGcd1oFpK7ayO7eQwn0cwbVxSiI5pUorvx18EEd0ac5v/xMekK4Z3JM7zzyEddv2kNa0AZMWbwk+JvurQd24+sQedK8gKB770EQ278qjVeMUZt8zrMJ87ckv5OpXZzIvYwcLHxi+T9cGqhoSkQNIs9Rk8GtWkhMTaNOkAf/8xVFl0h2U1oTrhh7M0d1a8sPqbZx1WAd+XL+Tji0aMrh3Gss27wbguJ6tefayo3DO8fiEJTw3eQVnD+jIqBO6c/5z06LO10m90vh8waawbTv25JcJAgAvfrOSoX3acmmEGe/emL6GN6avYezoQQzq6c2fsWlnLv+bt4HM7DymLtsarILalpMf7HuRmGB8PH8D5wzoFFZtd/M7c/lupddQ7pyrtLSxL1QiEJE6yTnHWz+sZVi/dsE5r3PyCpm7bgdHd/Mer527bgdrsnL437yNTFy0mfd+exzjZmfw7fIsHjrvMOZl7ODUQ9qyLTufXu2alltlFK3DOzdnfkbJ5EStGqfw4q+O5vq35lTaI7x3uyYs3ZzNPT/rx+qtOZjB/T8/lEPu/Zy8Qm8skPn3nx72VFRVqGpIROJadl4hRcWu3MH3Au7+8Ef+M30trRqncPOw3tzjD6l9Uq82XHZsV9K7t+L7ldt46NNFwSHIRw7swozV2+nbvilD+rQNm5yoPP07NeOMfu158suKx5xKSrCwqq5/jDySn/udDKtKVUMiEteaNIjuVnf3Wf04uVcapx/aHoCLju5cpvfzWYd34Mz+7bnilR+48OjOnHNEp+A+5xxtmqTQolEKj362mF7tmvC6338iIPCIqnOO7LxCXpyykkYpiRGfpgoEgTZNUtianc/ERZv3ORBURCUCEZEYyS0oou89nwNw55l9WZGZzV0jDqFFo5SwNJOXZIZ1sivtmZFHsi07j0sHdt3nx01VIhARqQWpyYnc87N+HNW1BUd2
bVlumiF90vj1iT246sQezFi9jTP7d2BbTj7bcvLp3qYRjVJie6tWiUBEJA5UVCKIv0E1REQkjAKBiEicUyAQEYlzCgQiInFOgUBEJM4pEIiIxDkFAhGROKdAICIS5+pchzIzywTWVJowsjbA1mrMTl2ga44Puub4sD/X3M05lxZpR50LBPvDzGaW17OuvtI1xwddc3yI1TWrakhEJM4pEIiIxLl4CwRjajsDtUDXHB90zfEhJtccV20EIiJSVryVCEREpBQFAhGROBc3gcDMhpvZEjNbbmZ31HZ+qoOZdTGzSWa20MwWmNmN/vZWZvalmS3z31v6283MnvH/BvPN7KjavYJ9Z2aJZjbHzD7213uY2ff+tb1jZin+9gb++nJ/f/dazfg+MrMWZva+mS02s0Vmdlx9/57N7Gb/3/VPZva2maXWt+/ZzF42sy1m9lPItip/r2Z2hZ9+mZldUdV8xEUgMLNE4FngTKAfMNLM+tVurqpFIXCLc64fMAi41r+uO4CvnHO9gK/8dfCuv5f/Gg08X/NZrjY3AotC1h8F/uacOxjYDlztb78a2O5v/5ufri56GvjcOdcXGIB37fX2ezazTsANQLpzrj+QCFxK/fueXwWGl9pWpe/VzFoB9wHHAgOB+wLBI2rOuXr/Ao4DJoSs3wncWdv5isF1fgQMA5YAHfxtHYAl/vKLwMiQ9MF0dekFdPb/g5wCfAwYXm/LpNLfNzABOM5fTvLTWW1fQxWvtzmwqnS+6/P3DHQC1gGt/O/tY+CM+vg9A92Bn/b1ewVGAi+GbA9LF80rLkoElPyjCsjwt9UbflH4SOB7oJ1zbqO/axPQzl+uL3+HvwO3AcX+emtgh3Ou0F8Pva7gNfv7d/rp65IeQCbwil8d9pKZNaYef8/OufXAE8BaYCPe9zaL+v09B1T1e93v7zteAkG9ZmZNgA+Am5xzu0L3Oe8nQr15RtjMfgZscc7Nqu281KAk4CjgeefckUAOJdUFQL38nlsC5+AFwY5AY8pWodR7NfW9xksgWA90CVnv7G+r88wsGS8IvOmcG+dv3mxmHfz9HYAt/vb68Hc4ATjbzFYDY/Gqh54GWphZkp8m9LqC1+zvbw5k1WSGq0EGkOGc+95ffx8vMNTn7/k0YJVzLtM5VwCMw/vu6/P3HFDV73W/v+94CQQzgF7+EwcpeI1O42s5T/vNzAz4N7DIOfdUyK7xQODJgSvw2g4C2y/3nz4YBOwMKYLWCc65O51znZ1z3fG+x6+dc5cBk4AL/WSlrznwt7jQT1+nfjk75zYB68ysj7/pVGAh9fh7xqsSGmRmjfx/54Frrrffc4iqfq8TgNPNrKVfkjrd3xa92m4oqcEGmRHAUmAF8Kfazk81XdOJeMXG+cBc/zUCr270K2AZMBFo5ac3vKenVgA/4j2RUevXsR/XPwT42F/uCfwALAfeAxr421P99eX+/p61ne99vNYjgJn+d/0h0LK+f8/An4HFwE/AG0CD+vY9A2/jtYEU4JX8rt6X7xW4yr/25cCVVc2HhpgQEYlz8VI1JCIi5VAgEBGJcwoEIiJxToFARCTOKRCIiMQ5BQKJe2bW3cxcqdeOGJznfv/YF1aeWqTmJFWeRCRuzAEe85fzazMjIjVJJQKREpl4HXgmAl+Z2Sj/F/x/zGyumW01s1sDic3sN/747zlm9oOZnehvTzGzh81sjZntNbMppc4z1Lx5BTLN7CL/Myf4Y8zn+tvfrqmLFlEgEClxOl4wyKSkWz/AULyx3zcBj5vZADM7BW8i8UzgD0BXYLyZtcYbEO4OYAFwHTC71HlO9Y/XHHjE33YbXq/Za4EH8IZRFqkRqhoSKfE9cLe/vB04zF9+2Tn3opkVAi8Bg/Fu/AD3Oee+NLOuwF14EwT9HG/oj0ucc7sjnOcp59wYM/sd3iQj4A0n8DO8IUJm4w0lIFIjVCIQKbHVOTfRf4UOc22l3kO5Uu/R2Oa/F1Lyf/B24Hy8gHA1MNPM
WlThmCL7TCUCkRIdzezSkPVk//1KM1uLN3UiwDd4A4PdAvzZzA7CnyoRmA78D0gH3jGz94HDnXM3VXLuO4E8vOqkdXjj8DcDduznNYlUSoFApMSReKNBBtzsv38F/B5oD/zROTcPwMxG49XtP4U3RPLNzrksM3sEaAhchjdfwg9RnLsYuN4/RxZwr3Nu7X5fkUgUNPqoSDnMbBTwCt7N/4lazo5IzKiNQEQkzqlEICIS51QiEBGJcwoEIiJxToFARCTOKRCIiMQ5BQIRkTj3//VLFVNqn829AAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train MSE: 0.3109\n",
      "Test MSE: 1.1823\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch_geometric_temporal.nn.attention import ASTGCN\n",
    "from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader\n",
    "from torch_geometric_temporal.signal import temporal_signal_split\n",
    "\n",
    "from torch_geometric.data import DataLoader\n",
    "\n",
    "# Seed torch's global RNG before the model and the shuffling DataLoader are created,\n",
    "# so that Restart & Run All should give comparable results.\n",
    "seed = 0\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "lags = 10         # time steps per input window (passed to get_dataset and ASTGCN len_input)\n",
    "stride = 1        # NOTE(review): not referenced anywhere else in this notebook\n",
    "epochs = 1000     # passes over the training loader\n",
    "batch_size = 32   # snapshots per DataLoader batch\n",
    "\n",
    "loader = ChickenpoxDatasetLoader()\n",
    "dataset = loader.get_dataset(lags)\n",
    "\n",
    "# Read the graph layout off the first snapshot; the model reuses this edge_index for every\n",
    "# batch, which assumes the graph does not change over time -- TODO confirm.\n",
    "sample = next(iter(dataset))\n",
    "num_nodes = sample.x.size(0)\n",
    "edge_index = sample.edge_index\n",
    "\n",
    "# NOTE(review): train_ratio=0.4 presumably assigns the earliest 40% of snapshots to\n",
    "# training and the rest to testing -- verify against temporal_signal_split.\n",
    "train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.4)\n",
    "\n",
    "# Materialise the snapshot iterators as lists before handing them to DataLoader.\n",
    "train_loader = DataLoader(list(train_dataset), batch_size=batch_size, shuffle=True)\n",
    "test_loader = DataLoader(list(test_dataset), batch_size=batch_size, shuffle=False)\n",
    "\n",
    "### MODEL DEFINITION\n",
    "class AttentionGCN(nn.Module):\n",
    "    \"\"\"Single-step, per-node forecaster built around one ASTGCN network.\n",
    "\n",
    "    Reads the module-level `lags`, `num_nodes` and `edge_index` defined earlier in this cell.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        self.attention = ASTGCN(\n",
    "            nb_block=3,\n",
    "            in_channels=1,\n",
    "            K=3,\n",
    "            nb_chev_filter=16,\n",
    "            nb_time_filter=16,\n",
    "            time_strides=1,\n",
    "            num_for_predict=1,\n",
    "            len_input=lags,\n",
    "            num_of_vertices=num_nodes,\n",
    "            normalization='sym',\n",
    "            bias=True,\n",
    "        )\n",
    "\n",
    "    def forward(self, window):\n",
    "        \"\"\"Predict one value per node for every snapshot in a batch.\n",
    "\n",
    "        Args:\n",
    "            window: batch yielded by the DataLoader; only `window.x` is read here.\n",
    "\n",
    "        Returns:\n",
    "            1-D tensor of predictions; the training and test loops assert that it has\n",
    "            the same shape as `window.y`.\n",
    "        \"\"\"\n",
    "        # Reshape to (batch, num_nodes, in_channels=1, lags).\n",
    "        # Presumably window.x arrives as (batch * num_nodes, lags) -- TODO confirm.\n",
    "        x = window.x.view(-1, num_nodes, 1, lags)\n",
    "        x = self.attention(x, edge_index)\n",
    "        # ASTGCN is documented to return (batch, num_nodes, num_for_predict).\n",
    "        # No squeeze on dim 1 here: that is the node axis, so squeezing it would only act\n",
    "        # (and break the 3-D permute) when num_nodes == 1.\n",
    "        return x.permute(0, 2, 1).flatten()\n",
    "    \n",
    "model = AttentionGCN()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "### TRAIN\n",
    "model.train()\n",
    "\n",
    "loss_history = []  # one training MSE per epoch, plotted in the results section\n",
    "for _ in tqdm(range(epochs)):\n",
    "    squared_error_sum = 0.0\n",
    "    num_targets = 0\n",
    "    for window in train_loader:\n",
    "        optimizer.zero_grad()\n",
    "        y_pred = model(window)\n",
    "\n",
    "        # Guard against silent broadcasting in the subtraction below.\n",
    "        assert y_pred.shape == window.y.shape\n",
    "        loss = torch.mean((y_pred - window.y)**2)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Weight each batch by its size: the last batch can be smaller, so a plain\n",
    "        # average of per-batch means would not be the MSE over the epoch's targets.\n",
    "        batch_targets = window.y.numel()\n",
    "        squared_error_sum += loss.item() * batch_targets\n",
    "        num_targets += batch_targets\n",
    "    # total_loss keeps the last epoch's value; it is printed as the train MSE later.\n",
    "    total_loss = squared_error_sum / num_targets\n",
    "    loss_history.append(total_loss)\n",
    "\n",
    "### TEST\n",
    "model.eval()\n",
    "squared_error_sum = 0.0\n",
    "num_targets = 0\n",
    "with torch.no_grad():  # evaluation only: no gradients are needed\n",
    "    for window in test_loader:\n",
    "        y_pred = model(window)\n",
    "\n",
    "        # Guard against silent broadcasting in the subtraction below.\n",
    "        assert y_pred.shape == window.y.shape\n",
    "        squared_error_sum += torch.sum((y_pred - window.y)**2).item()\n",
    "        num_targets += window.y.numel()\n",
    "# MSE over every test target as a plain float. Summing squared errors and dividing once\n",
    "# avoids over-weighting a smaller final batch.\n",
    "loss = squared_error_sum / num_targets\n",
    "\n",
    "### RESULTS PLOT\n",
    "# Epochs are numbered from 1 on the x axis.\n",
    "epoch_axis = np.arange(1, epochs+1)\n",
    "label_style = dict(fontweight='bold')\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(epoch_axis, loss_history)\n",
    "\n",
    "# figure labels\n",
    "ax.set_title('Loss over time', **label_style)\n",
    "ax.set_xlabel('Epochs', **label_style)\n",
    "ax.set_ylabel('Mean Squared Error', **label_style)\n",
    "plt.show()\n",
    "\n",
    "# total_loss is the training loss of the last epoch; loss is the test-set loss.\n",
    "print(\"Train MSE: {:.4f}\".format(total_loss))\n",
    "print(\"Test MSE: {:.4f}\".format(loss))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d25258e0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
